#include "tensorrt_onnx.h"

#include <vector>

namespace {

// Returns true when `file_name` exists and can be opened for reading;
// otherwise logs a message and returns false.
bool isFileExists(const std::string &file_name) {
  std::ifstream fin(file_name);
  if (fin) {
    return true;
  }
  // BUG FIX: message grammar ("The file is not exist"); also removed the
  // unreachable trailing `return true;` that followed an exhaustive if/else.
  std::cout << "The file does not exist: " << file_name << std::endl;
  return false;
}

} // namespace

// Releases every resource acquired in Init()/AllocateBuffers(): device-side
// binding buffers, the CUDA stream, the TensorRT context/engine, and the
// host-side output staging buffers.
TensorrtOnnxInference::~TensorrtOnnxInference() {
  // Free the device buffers backing each input and output binding slot.
  for (const int index : input_indexes_) {
    CUDA_CHECK(cudaFree(buffers_[index]));
  }
  for (const int index : output_indexes_) {
    CUDA_CHECK(cudaFree(buffers_[index]));
  }

  cudaStreamDestroy(stream_);

  // TensorRT objects are released via their explicit destroy() method.
  if (context_ != nullptr) {
    context_->destroy();
  }
  if (engine_ != nullptr) {
    engine_->destroy();
  }

  // Host staging buffers were allocated with new[] in AllocateBuffers().
  for (auto &host_buf : output_buffers_) {
    if (host_buf != nullptr) {
      delete[] host_buf;
      host_buf = nullptr;
    }
  }
}

// Prepares the engine for inference: verifies the ONNX model exists,
// builds (and caches) or loads the serialized engine file next to it,
// creates the execution context, and allocates all binding buffers.
// Returns false on any failure.
bool TensorrtOnnxInference::Init() {
  if (!isFileExists(model_path_)) {
    std::cout << "Model file: " << model_path_ << " does not exist!"
              << std::endl;
    return false;
  }

  // Derive the engine-cache path by swapping the ".onnx" suffix for ".engine".
  const std::size_t ext_pos = model_path_.find(".onnx");
  if (ext_pos == std::string::npos) {
    std::cout << "The model file should be onnx format!\n";
    return false;
  }
  const std::string engine_path = model_path_.substr(0, ext_pos) + ".engine";

  if (!isFileExists(engine_path)) {
    std::cout << "The engine file " << engine_path
              << " has not been generated, try to generate..." << std::endl;
    engine_ = SerializeToEngineFile(model_path_, engine_path);
    // BUG FIX: assert() is a no-op under NDEBUG, so a failed build used to
    // fall through and dereference a null engine later. Report and bail.
    if (engine_ == nullptr) {
      std::cout << "Failed to generate engine from: " << model_path_
                << std::endl;
      return false;
    }
    std::cout << "Succeeded to generate engine file: " << engine_path
              << std::endl;
  } else {
    std::cout << "Use the existing engine file: " << engine_path << std::endl;
    engine_ = LoadFromEngineFile(engine_path);
    if (engine_ == nullptr) {
      std::cout << "Failed to load engine file: " << engine_path << std::endl;
      return false;
    }
  }

  context_ = engine_->createExecutionContext();
  if (context_ == nullptr) {
    std::cout << "Failed to create execution context" << std::endl;
    return false;
  }

  AllocateBuffers();
  return true;
}

// Runs one synchronous inference pass.
//
// input_data:  one host buffer per engine input binding, in binding order;
//              each must hold the full input tensor for that binding.
// output_data: cleared and refilled by doInference() with pointers into
//              internally-owned host buffers (output_buffers_), one per
//              output binding. The pointed-to data stays valid only until
//              the next Infer() call or object destruction.
//
// Always returns true: doInference() returns void, so there is no failure
// status to propagate.
bool TensorrtOnnxInference::Infer(const std::vector<float *> &input_data,
                                  std::vector<const float *> &output_data) {
  doInference(input_data, output_data);
  return true;
}

// Builds a TensorRT engine from the ONNX model at `model_path`, writes the
// serialized engine blob to `engine_path`, and returns the live engine
// (ownership passes to the caller). Returns nullptr if parsing fails.
nvinfer1::ICudaEngine *
TensorrtOnnxInference::SerializeToEngineFile(const std::string &model_path,
                                             const std::string &engine_path) {
  nvinfer1::IBuilder *builder = nvinfer1::createInferBuilder(gLogger_);
  assert(builder != nullptr);

  // ONNX parsing requires an explicit-batch network definition.
  const auto explicitBatch =
      1U << static_cast<uint32_t>(
          nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
  nvinfer1::INetworkDefinition *network =
      builder->createNetworkV2(explicitBatch);

  auto parser = nvonnxparser::createParser(*network, gLogger_);
  assert(parser != nullptr);

  std::cout << "Start parsing model..." << std::endl;
  int verbosity =
      static_cast<int>(nvinfer1::ILogger::Severity::kINTERNAL_ERROR);
  if (!parser->parseFromFile(model_path.c_str(), verbosity)) {
    std::cout << "Failed to parse onnx model: " << model_path << std::endl;
    // BUG FIX: release the TensorRT objects on the failure path too
    // (they were leaked before).
    parser->destroy();
    network->destroy();
    builder->destroy();
    return nullptr;
  }
  std::cout << "End parsing model" << std::endl;

  auto config = builder->createBuilderConfig();
  config->setMaxWorkspaceSize(1 << 25);  // 32 MiB scratch workspace

  nvinfer1::ICudaEngine *engine =
      builder->buildEngineWithConfig(*network, *config);
  assert(engine);

  // Serialize the engine and persist it.
  // BUG FIX: the blob is binary data, so the file must be opened with
  // std::ios::binary — text mode corrupts it on platforms that translate
  // line endings. The intermediate stringstream (and the meaningless
  // seekg() on an output stream) was dropped; write the blob directly.
  nvinfer1::IHostMemory *trtModelStream = engine->serialize();
  std::ofstream outFile(engine_path, std::ios::binary);
  outFile.write(static_cast<const char *>(trtModelStream->data()),
                static_cast<std::streamsize>(trtModelStream->size()));
  outFile.close();
  trtModelStream->destroy();  // BUG FIX: the IHostMemory blob was leaked.

  // BUG FIX: destroy in reverse order of creation (config before builder;
  // it was previously destroyed after the builder that created it).
  config->destroy();
  parser->destroy();
  network->destroy();
  builder->destroy();

  return engine;
}

// Deserializes a previously cached engine file into a live TensorRT engine.
// Ownership of the returned engine passes to the caller; returns whatever
// deserializeCudaEngine() produces (nullptr on failure).
nvinfer1::ICudaEngine *
TensorrtOnnxInference::LoadFromEngineFile(const std::string &engine_path) {
  std::cout << "Loading model from engine file...\n";
  // BUG FIX: the engine blob is binary data — it must be read with
  // std::ios::binary, or newline translation mangles it on some platforms.
  // Opening with std::ios::ate gives the file size directly, replacing the
  // needless stringstream round-trip of the original.
  std::ifstream cache(engine_path, std::ios::binary | std::ios::ate);
  assert(cache.good());
  const std::streamsize modelSize = cache.tellg();
  cache.seekg(0, std::ios::beg);

  // BUG FIX: std::vector replaces an unchecked malloc()/free() pair, so the
  // buffer is released even if deserialization throws.
  std::vector<char> modelMem(static_cast<std::size_t>(modelSize));
  cache.read(modelMem.data(), modelSize);
  cache.close();

  nvinfer1::IRuntime *runtime = nvinfer1::createInferRuntime(gLogger_);
  nvinfer1::ICudaEngine *engine = runtime->deserializeCudaEngine(
      modelMem.data(), static_cast<std::size_t>(modelSize), nullptr);
  runtime->destroy();
  std::cout << "Load complete." << std::endl;

  return engine;
}

void TensorrtOnnxInference::AllocateBuffers() {
  for (int i = 0; i < engine_->getNbBindings(); ++i) {
    if (engine_->bindingIsInput(i)) {
      input_indexes_.push_back(i);
    } else {
      output_indexes_.push_back(i);
    }
  }

  CUDA_CHECK(cudaStreamCreate(&stream_));
  for (std::size_t i = 0; i < input_indexes_.size(); ++i) {
    const int index = input_indexes_.at(i);
    std::cout << "alloc memory for input " << index << std::endl;
    nvinfer1::Dims input_dim = engine_->getBindingDimensions(index);
    int input_size = 1;
    for (int j = 0; j < input_dim.nbDims; ++j) {
      input_size *= input_dim.d[j];
    }
    input_sizes_.push_back(input_size);
    CUDA_CHECK(cudaMalloc(&buffers_[index], 1 * input_size * sizeof(float)));
    model_input_dims_.push_back(std::make_pair(input_dim.d[2], input_dim.d[3]));

    std::cout << "The input size (NCHW) is : ";
    for (int j = 0; j < input_dim.nbDims; ++j) {
      std::cout << input_dim.d[j];
      if (j != input_dim.nbDims - 1) {
        std::cout << " * ";
      } else {
        std::cout << " = ";
      }
    }
    std::cout << input_size << std::endl;
  }

  for (std::size_t i = 0; i < output_indexes_.size(); ++i) {
    const int index = output_indexes_.at(i);
    std::cout << "alloc memory for output " << index << std::endl;
    nvinfer1::Dims output_dim = engine_->getBindingDimensions(index);
    int output_size = 1;
    for (int j = 0; j < output_dim.nbDims; ++j) {
      output_size *= output_dim.d[j];
    }
    output_sizes_.push_back(output_size);
    CUDA_CHECK(cudaMalloc(&buffers_[index], 1 * output_size * sizeof(float)));

    float *buf = new float[output_size]();
    output_buffers_.push_back(buf);

    std::cout << "The output size (NCHW) is: ";
    for (int j = 0; j < output_dim.nbDims; ++j) {
      std::cout << output_dim.d[j];
      if (j != output_dim.nbDims - 1) {
        std::cout << " * ";
      } else {
        std::cout << " = ";
      }
    }
    std::cout << output_size << std::endl;
  }
}

// Copies the host inputs to the device, runs the network on stream_, copies
// the results back into output_buffers_, and publishes those host pointers
// through `output_data`. Blocks until all async work has completed.
void TensorrtOnnxInference::doInference(
    const std::vector<float *> &input_data,
    std::vector<const float *> &output_data) {
  assert(input_data.size() == input_indexes_.size());

  // Stage each host input into its device-side binding buffer.
  for (std::size_t i = 0; i < input_data.size(); ++i) {
    const int index = input_indexes_.at(i);
    CUDA_CHECK(cudaMemcpyAsync(buffers_[index], input_data.at(i),
                               1 * input_sizes_.at(i) * sizeof(float),
                               cudaMemcpyHostToDevice, stream_));
  }

  // BUG FIX: the engine is built with kEXPLICIT_BATCH (see
  // SerializeToEngineFile), so enqueueV2 is the correct API — enqueue()
  // with a batch-size argument is for implicit-batch engines. Also report
  // failure instead of silently ignoring the return value.
  if (!context_->enqueueV2(buffers_, stream_, nullptr)) {
    std::cout << "TensorRT enqueueV2 failed" << std::endl;
  }

  // Copy device outputs back into the owned host staging buffers and hand
  // out pointers to them. They remain valid until the next call.
  output_data.clear();
  for (std::size_t i = 0; i < output_indexes_.size(); ++i) {
    const int index = output_indexes_.at(i);
    CUDA_CHECK(cudaMemcpyAsync(output_buffers_.at(i), buffers_[index],
                               1 * output_sizes_.at(i) * sizeof(float),
                               cudaMemcpyDeviceToHost, stream_));
    output_data.push_back(output_buffers_.at(i));
  }

  // Ensure all copies and the inference itself have finished before the
  // caller reads the output buffers.
  cudaStreamSynchronize(stream_);
}

// Returns the spatial dims (dims.d[2], dims.d[3]) recorded for the
// `index`-th input binding during AllocateBuffers() — (height, width) for
// an NCHW-layout input. Throws std::out_of_range if `index` is invalid.
std::pair<int, int>
TensorrtOnnxInference::GetModelInputDims(const int index) const {
  return model_input_dims_.at(index);
}
